# Result plots 
library(pacman)
p_load(
  data.table, here, ggplot2, fixest, magrittr, dplyr, ggExtra, 
  haven, purrr, sf, tigris, tidyverse, janitor, 
  latex2exp, scales, RColorBrewer, readxl
)
# Reuse previously downloaded tigris shapefiles instead of downloading again
options(tigris_use_cache = TRUE)
# Reference figure width in inches; the default legend key width is WIDTH/6 cm
WIDTH = 8
# Default theme for every plot: blank canvas, no margins, legend underneath
theme_set(
  theme_void(base_size = 14) +
    theme(
      plot.margin = margin(0, 0, 0, 0),
      legend.position = 'bottom',
      legend.text = element_text(size = 18),
      legend.key.width = unit(WIDTH / 6, 'cm')
    )
)

# Loading shape files 
# 2020 cartographic-boundary state polygons, reprojected to EPSG:2163
states_sf = 
  states(cb = TRUE, year = 2020) |>
  # Limiting to continental US by dropping these STATEFP codes
  filter(
    !STATEFP %in% c("02", "15", "60", "66", "69", "72", "78")
  ) |>
  # Lower-case snake_case column names (e.g. stusps, name) used downstream
  clean_names() |>
  st_transform(crs = 2163)

# Loading results 
# Welfare-max, cost-neutral results by state: baseline and optimal levels of
# subsidies and installations plus several baseline-vs-optimal comparisons
welfare_max_cost_neutral_dt = 
  data.table(
    read_xls(here('StructuralCode/Results/LevelsTableTCollPercPolDOptCN.xls'))
  )[
    # Skipping rows without a state and the 'all' row
    !is.na(State) & State != 'all',
    {
      # Coercing each input column to numeric once, then reusing it below
      sub_b = as.numeric(sub_base_state)
      sub_o = as.numeric(sub_opt_state)
      inst_b = as.numeric(BI_base_stateper1000HH)
      inst_o = as.numeric(BI_opt_stateper1000HH)
      # Suffixes: _rb divides by the baseline, _ro divides by the optimum;
      # ratios are baseline over optimal
      list(
        stusps = State,
        wm_sub_base = sub_b,
        wm_sub_opt = sub_o,
        wm_sub_diff = sub_o - sub_b,
        wm_sub_log_diff = log(sub_o) - log(sub_b),
        wm_sub_perc_diff_rb = (sub_o - sub_b)/sub_b,
        wm_sub_perc_diff_ro = (sub_b - sub_o)/sub_o,
        wm_sub_ratio = sub_b/sub_o,
        wm_install_per_1000HH_base = inst_b,
        wm_install_per_1000HH_opt = inst_o,
        wm_install_diff = inst_o - inst_b,
        wm_install_log_diff = log(inst_o) - log(inst_b),
        wm_install_perc_diff_rb = (inst_o - inst_b)/inst_b,
        wm_install_perc_diff_ro = (inst_b - inst_o)/inst_o,
        wm_install_ratio = inst_b/inst_o
      )
    }
  ]
# Damage-min, cost-neutral results by state, built the same way: baseline and 
# optimal levels plus several baseline-vs-optimal comparisons
damage_min_cost_neutral_dt = 
  data.table(
    read_xls(here('StructuralCode/Results/LevelsTableTCollPercPolDOptMD.xls'))
  )[
    # Skipping rows without a state and the 'all' row
    !is.na(State) & State != 'all',
    {
      # Coercing each input column to numeric once, then reusing it below
      sub_b = as.numeric(sub_base_state)
      sub_o = as.numeric(sub_opt_state)
      inst_b = as.numeric(BI_base_stateper1000HH)
      inst_o = as.numeric(BI_opt_stateper1000HH)
      # Suffixes: _rb divides by the baseline, _ro divides by the optimum;
      # ratios are baseline over optimal
      list(
        stusps = State,
        dm_sub_base_state = sub_b,
        dm_sub_opt_state = sub_o,
        dm_sub_diff = sub_o - sub_b,
        dm_sub_log_diff = log(sub_o) - log(sub_b),
        dm_sub_perc_diff_rb = (sub_o - sub_b)/sub_b,
        dm_sub_perc_diff_ro = (sub_b - sub_o)/sub_o,
        dm_sub_ratio = sub_b/sub_o,
        dm_install_per_1000HH_base = inst_b,
        dm_install_per_1000HH_opt = inst_o,
        dm_install_diff = inst_o - inst_b,
        dm_install_log_diff = log(inst_o) - log(inst_b),
        dm_install_perc_diff_rb = (inst_o - inst_b)/inst_b,
        dm_install_perc_diff_ro = (inst_b - inst_o)/inst_o,
        dm_install_ratio = inst_b/inst_o
      )
    }
  ]
# Merging them together 
# Wide table (one row per state) stacked into long form: one row per 
# state x original result column
long_results_dt = 
  merge(
    welfare_max_cost_neutral_dt,
    damage_min_cost_neutral_dt, 
    by = 'stusps',
    all.x = TRUE
  ) |>
  melt(id.vars = 'stusps')
# Splitting the original column name into outcome, problem, and scenario
long_results_dt = long_results_dt[, .(
  stusps, 
  # Outcome: subsidy columns all carry 'sub' in their name 
  variable = ifelse(str_detect(variable, 'sub'), 'subsidies', 'installs'),
  # Problem: the two-letter column prefix ('wm' or 'dm')
  problem = str_sub(variable, 1, 2),
  # Scenario: the first matching pattern wins, so the specific 'log_diff' and 
  # 'perc_diff_*' patterns have to be tested before the plain 'diff'
  scenario = fcase(
    str_detect(variable, '_base'), 'Baseline', 
    str_detect(variable, '_opt'), 'Optimal', 
    str_detect(variable, 'log_diff'), 'log_diff',
    str_detect(variable, 'perc_diff_rb'), 'perc_diff_rel_base',
    str_detect(variable, 'perc_diff_ro'), 'perc_diff_rel_opt',
    str_detect(variable, 'diff'), 'diff',
    str_detect(variable, 'ratio'), 'ratio'
  ),
  value
)]
# Creating ranks 
# Ranking states within each outcome x problem x scenario group
long_results_dt[, rank := frank(value), by = .(variable, problem, scenario)]
# Attaching the results to the state polygons, keeping every polygon
state_long_results_sf = 
  merge(states_sf, long_results_dt, by = 'stusps', all.x = TRUE)

# Maps the level scenarios of one outcome, one facet per scenario, and saves 
# the figure as a jpeg
#   variable_in: 'subsidies'; any other value is handled as installations
#   problem_in:  problem prefix to keep ('wm' or 'dm')
#   filename_in: file name (without extension) under figures/maps/results
#   width_in:    width of the saved figure in inches; height is width_in/2.67
levels_plots = function(
    variable_in, 
    problem_in, 
    filename_in,
    width_in = 8
  ){
  # Keeping only the level scenarios, i.e. no differences and no ratios
  data_sf = 
    state_long_results_sf |> 
    filter(
      problem == problem_in,
      variable == variable_in,
      !str_detect(scenario, 'diff'),
      !str_detect(scenario, 'ratio')
    )
  # Setting the legend title, breaks, labels, and limits
  if(variable_in == 'subsidies'){
    name_in = 'Expected state subsidies ($1000)'
    breaks_in = seq(5, 25, by = 5)
    labels_in = breaks_in
    limits_in = c(0, 30)
  }else{
    name_in = 'Installations per 1000 households'
    # Installations are colored on a log scale: the fill is log(value), the 
    # breaks are logged too, and the labels show the unlogged numbers
    data_sf = mutate(data_sf, value = log(value))
    labels_in = c(2.5, 7.5, 20, 55)
    breaks_in = log(labels_in)
    limits_in = range(data_sf$value)
  }
  colorbar_in = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  levels_p = 
    ggplot() +
    geom_sf(data = data_sf, aes(fill = value), color = 'black') + 
    scale_fill_gradientn(
      name = name_in,
      colors = brewer.pal(n = 9, name = 'YlGnBu'),
      limits = limits_in,
      breaks = breaks_in,
      labels = labels_in,
      guide = colorbar_in
    ) + 
    facet_grid(cols = vars(scenario))
  ggsave(
    filename = here(paste0('figures/maps/results/', filename_in, '.jpeg')),
    plot = levels_p,
    width = width_in, height = width_in/2.67, units = 'in',
    bg = 'white'
  )
}

# Running for both cases 
# One row per saved figure; rows run top to bottom
pwalk(
  tribble(
    ~problem_in, ~variable_in, ~filename_in,
    'wm', 'subsidies', 'state-subsidy-levels-cost-neutral-welfare-max',
    'dm', 'subsidies', 'state-subsidy-levels-cost-neutral-damage-min',
    'wm', 'installs', 'install-levels-cost-neutral-welfare-max',
    'dm', 'installs', 'install-levels-cost-neutral-damage-min'
  ),
  levels_plots
)

p_load(colorspace, scales)
# Now plotting differences 
# Maps one baseline-vs-optimal comparison measure by state, optionally 
# printing and/or saving the figure
#   name_in:     legend title
#   problem_in:  problem prefix to keep ('wm' or 'dm')
#   variable_in: 'subsidies' or 'installs'
#   scenario_in: one scenario label, e.g. 'diff', 'log_diff', 'ratio', 
#                'perc_diff_rel_base', or 'perc_diff_rel_opt'
#   filename_in: file name (without extension) under figures/maps/results; 
#                NULL (the default) skips saving
#   width_in:    width of the saved figure in inches; height is width_in/1.4
#   print_plot:  if TRUE, the plot is printed 
diff_plots = function(
    name_in,
    problem_in, 
    variable_in,
    scenario_in,
    filename_in = NULL,
    width_in = 8, 
    print_plot = TRUE
  ){
  # Filtering the data 
  plot_sf = 
    state_long_results_sf |> 
    filter(
      problem == problem_in 
      & scenario == scenario_in
      & variable == variable_in
    ) |>
    arrange(rank)
  # Colorbar styling shared by every scale below
  colorbar_in = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  # Setting the fill scale
  if(scenario_in == 'ratio'){
    # Ratios use a sequential palette on a log10 scale with hard-coded limits;
    # ratios outside those limits are out of bounds for the scale
    fill_scale = scale_fill_gradientn(
      colors = brewer.pal(n = 9, name = 'YlGnBu'), 
      trans = 'log10',
      labels = label_percent(),
      name = name_in,
      limits = c(0.22, 6),
      guide = colorbar_in
    )
  }else{
    # Everything else uses a diverging palette centered at zero; only the 
    # format of the legend labels depends on what is being plotted
    if(str_detect(scenario_in, 'perc_diff')){
      labels_in = label_percent(big.mark = '')
    }else if(scenario_in == 'diff' && variable_in == 'subsidies'){
      labels_in = label_dollar(suffix = 'K', largest_with_cents = 0)
    }else{
      labels_in = scales::number
    }
    fill_scale = scale_fill_continuous_diverging(
      name = name_in, 
      palette = 'Green-Brown',
      mid = 0, 
      rev = TRUE, 
      labels = labels_in,
      guide = colorbar_in
    )
  }
  # Plotting
  diff_p = 
    ggplot() + 
    geom_sf(data = plot_sf, aes(fill = value), color = 'black') + 
    fill_scale + 
    theme(legend.key.width = unit(width_in/4, 'cm'))
  # Printing the plot if specified
  if(print_plot == TRUE) print(diff_p)
  # Saving the plot if specified
  if(!is.null(filename_in)){
    ggsave(
      plot = diff_p,
      filename = here(paste0('figures/maps/results/',filename_in,'.jpeg')),
      bg = 'white', 
      width = width_in, height = width_in/1.4, units = 'in'
    )
  }
}

# Testing
#diff_plots(
#  name_in = '',
#  problem_in = 'wm',
#  variable_in = 'subsidies',
#  scenario_in = 'diff',  
#)


# Main paper result plots -----------------------------------------------------
# Subsidy difference, install ratio (baseline/optimal)
# One row per figure, run top to bottom. Legend titles are left blank; a 
# descriptive title for the subsidy maps would be 
# 'Difference in expected state subsidies ($1000)'
pwalk(
  tribble(
    ~problem_in, ~variable_in, ~scenario_in, ~filename_in,
    'wm', 'subsidies', 'diff', 'state-subsidy-diff-cost-neutral-welfare-max',
    'dm', 'subsidies', 'diff', 'state-subsidy-diff-cost-neutral-damage-min',
    'wm', 'installs', 'ratio', 'install-ratio-cost-neutral-welfare-max',
    'dm', 'installs', 'ratio', 'install-ratio-cost-neutral-damage-min'
  ),
  \(...) diff_plots(name_in = '', ...)
)



# Appendix result plots -------------------------------------------------------
# One row per figure, run top to bottom. Legend titles are left blank; 
# descriptive titles would be 
#   subsidies, perc_diff_*: 'Percent difference in expected state subsidies'
#   installs, diff:         'Difference in installations per 1000 households'
#   installs, log_diff:     'Difference in log installations per 1000 households'
#   installs, perc_diff_*:  'Percent difference in installations per 1000 households'
pwalk(
  tribble(
    ~problem_in, ~variable_in, ~scenario_in, ~filename_in,
    # Percent change in subsidies 
    'wm', 'subsidies', 'perc_diff_rel_base',
      'state-subsidy-perc-diff-rel-base-cost-neutral-welfare-max',
    'dm', 'subsidies', 'perc_diff_rel_base',
      'state-subsidy-perc-diff-rel-base-cost-neutral-damage-min',
    'wm', 'subsidies', 'perc_diff_rel_opt',
      'state-subsidy-perc-diff-rel-opt-cost-neutral-welfare-max',
    'dm', 'subsidies', 'perc_diff_rel_opt',
      'state-subsidy-perc-diff-rel-opt-cost-neutral-damage-min',
    # Now for change in installations 
    'wm', 'installs', 'diff',
      'install-diff-cost-neutral-welfare-max',
    'dm', 'installs', 'diff',
      'install-diff-cost-neutral-damage-min',
    'wm', 'installs', 'log_diff',
      'install-log-diff-cost-neutral-welfare-max',
    'dm', 'installs', 'log_diff',
      'install-log-diff-cost-neutral-damage-min',
    'wm', 'installs', 'perc_diff_rel_base',
      'install-perc-diff-rel-base-cost-neutral-welfare-max',
    'dm', 'installs', 'perc_diff_rel_base',
      'install-perc-diff-rel-base-cost-neutral-damage-min',
    'wm', 'installs', 'perc_diff_rel_opt',
      'install-perc-diff-rel-opt-cost-neutral-welfare-max',
    'dm', 'installs', 'perc_diff_rel_opt',
      'install-perc-diff-rel-opt-cost-neutral-damage-min'
  ),
  \(...) diff_plots(name_in = '', ...)
)


# Plotting just optimal subsidies
width_in = 8
# Subsidies outside censor_lim are pulled to the nearest limit before 
# plotting, which is why the end labels carry '<' and '>' 
censor_lim = c(5,20)
breaks_in = c(5,10,15,20)

# Welfare-max problem
optimal_subsidies_wm_p = 
  ggplot() +
  geom_sf(
    data = state_long_results_sf |> 
      filter(problem == 'wm', variable == 'subsidies', scenario == 'Optimal') |>
      mutate(value = pmin(pmax(value, censor_lim[1]), censor_lim[2])), 
    aes(fill = value),
    color = 'black'
  ) + 
  scale_fill_gradientn(
    name = '',
    colors = brewer.pal(n = 9, name = 'YlGnBu'),
    limits = censor_lim,
    breaks = breaks_in,
    labels = paste0(c('<$','$','$','>$'), breaks_in, 'K'),
    guide = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  ) + 
  theme(legend.key.width = unit(width_in/4, 'cm'))
optimal_subsidies_wm_p
ggsave(
  filename = here('figures/maps/results/state-subsidy-opt-cost-neutral-welfare-max.jpeg'),
  plot = optimal_subsidies_wm_p,
  width = width_in, height = width_in/1.4, units = 'in',
  bg = 'white'
)

# Damage-min problem
optimal_subsidies_dm_p = 
  ggplot() +
  geom_sf(
    data = state_long_results_sf |> 
      filter(problem == 'dm', variable == 'subsidies', scenario == 'Optimal') |>
      mutate(value = pmin(pmax(value, censor_lim[1]), censor_lim[2])), 
    aes(fill = value),
    color = 'black'
  ) + 
  scale_fill_gradientn(
    name = '',
    colors = brewer.pal(n = 9, name = 'YlGnBu'),
    limits = censor_lim,
    breaks = breaks_in,
    labels = paste0(c('<$','$','$','>$'), breaks_in, 'K'),
    guide = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  ) + 
  theme(legend.key.width = unit(width_in/4, 'cm'))
optimal_subsidies_dm_p
ggsave(
  filename = here('figures/maps/results/state-subsidy-opt-cost-neutral-damage-min.jpeg'),
  plot = optimal_subsidies_dm_p,
  width = width_in, height = width_in/1.4, units = 'in',
  bg = 'white'
)
# A different plot for the ratio
# Ratios (baseline/optimal) outside censor_lim_ratio are pulled to the nearest 
# limit, and the fill is the log of the ratio, so limits and breaks are logged
# while the labels show the unlogged percentages
censor_lim_ratio = c(0.25, 6)
breaks_in_ratio = c(0.25,0.75,2,6)
wm_subsidy_ratio = 
  ggplot() +
  geom_sf(
    data = state_long_results_sf |> 
      filter(
        problem == 'wm',
        variable == 'subsidies',
        str_detect(scenario, 'ratio')
      ) |>
      mutate(
        value = pmin(pmax(value, censor_lim_ratio[1]), censor_lim_ratio[2])
      ), 
    aes(fill = log(value)),
    color = 'black'
  ) + 
  scale_fill_gradientn(
    name = '',
    colors = brewer.pal(n = 9, name = 'YlGnBu'),
    limits = log(censor_lim_ratio),
    breaks = log(breaks_in_ratio),
    # '<' and '>' mark the two censored ends of the scale
    labels = paste0(c('<','','','>'), percent(breaks_in_ratio)),
    guide = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  ) + 
  theme(legend.key.width = unit(width_in/4, 'cm'))
wm_subsidy_ratio
ggsave(
  filename = here('figures/maps/results/state-subsidy-base-opt-ratio-cost-neutral-welfare-max.jpeg'),
  plot = wm_subsidy_ratio,
  width = width_in, height = width_in/1.4, units = 'in',
  bg = 'white'
)
# Subsidy ratio (baseline/optimal) map for the damage-min problem. Ratios 
# outside censor_lim_ratio are pulled to the nearest limit, and the fill is 
# the log of the ratio, so limits and breaks are logged while the labels show 
# the unlogged percentages
dm_subsidy_ratio = 
  ggplot() +
    geom_sf(
      data = state_long_results_sf |> 
        filter(
          problem == 'dm' 
          & variable == 'subsidies'
          & str_detect(scenario, 'ratio')
        )|>
        mutate(
          value = fcase(
            value < censor_lim_ratio[1], censor_lim_ratio[1], 
            value >= censor_lim_ratio[1] & value <= censor_lim_ratio[2], value, 
            value > censor_lim_ratio[2], censor_lim_ratio[2]   
          )
        ), 
      aes(fill = log(value)),
      color = 'black'
    ) + 
    scale_fill_gradientn(
      name = '',
      colors = brewer.pal(n = 9, name = 'YlGnBu'), #'BrBG'
      limits = log(censor_lim_ratio),
      breaks = log(breaks_in_ratio),
      # One prefix per break. With only three prefixes paste0() recycles them 
      # and the '>' lands on the third label instead of the censored top one
      labels = paste0(c('','','','>'), percent(breaks_in_ratio)),
      guide = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
    ) + 
    theme(legend.key.width= unit(width_in/4, 'cm'))
ggsave(
  plot = dm_subsidy_ratio,
  filename = here(paste0('figures/maps/results/state-subsidy-base-opt-ratio-cost-neutral-damage-min.jpeg')),
  bg = 'white', 
  width = width_in, height = width_in/1.4, units = 'in'
)

# Marginal subsidy increases  
p_load(readxl)

# NOTE(review): the spreadsheet is attached column-wise, so its rows are 
# assumed to be in the same order as the states sorted by name with DC 
# removed -- confirm against MargSubsidy.xls
marg_sub_result_sf = 
  states_sf |> 
  filter(stusps != 'DC') |>
  arrange(name) |>
  cbind(
    read_xls(
      here('StructuralCode/Results/MargSubsidy.xls'),
      col_types = 'numeric'
    )
  )
  # Alternative kept for reference: joining on the state abbreviation
  #inner_join(
  #  states_sf,
  #  read_xls(
  #    here('StructuralCode/Results/MargSubsidy.xls'),
  #    col_names = c('unit_sub', 'cost_sub', 'kwh_sub', 'stusps',paste0('x',1:7))
  #  ),
  #  by = 'stusps'
  #) 
# Map of marg_sub_kwh by state; presumably a column read from the 
# spreadsheet -- TODO confirm
marg_sub_kwh_p = 
  ggplot() +
  geom_sf(
    data = marg_sub_result_sf, 
    aes(fill = marg_sub_kwh),
    color = 'black'
  ) + 
  scale_fill_gradientn(
    name = '',
    colors = brewer.pal(n = 9, name = 'YlGnBu'),
    # Labelling both ends of the observed range plus two interior values
    breaks = c(
      min(marg_sub_result_sf$marg_sub_kwh),
      0.4, 0.6,
      max(marg_sub_result_sf$marg_sub_kwh)
    ),
    labels = scales::label_dollar(), 
    guide = guide_colorbar(ticks.colour = "white", ticks.linewidth = 1.25)
  ) + 
  theme(legend.key.width = unit(width_in/4, 'cm'))
ggsave(
  filename = here('figures/maps/results/marg-subsidy-kwh.jpeg'),
  plot = marg_sub_kwh_p,
  width = width_in, height = width_in/1.4, units = 'in',
  bg = 'white'
)


